{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6bc2e9e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy      : 1.23.4\n",
      "scipy      : 1.9.3\n",
      "torch      : 1.12.1\n",
      "pomegranate: 0.14.8\n",
      "\n",
      "Compiler    : GCC 11.2.0\n",
      "OS          : Linux\n",
      "Release     : 4.15.0-197-generic\n",
      "Machine     : x86_64\n",
      "Processor   : x86_64\n",
      "CPU cores   : 8\n",
      "Architecture: 64bit\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import numpy\n",
    "import scipy\n",
    "import torch\n",
    "\n",
    "from torchegranate.distributions import *\n",
    "\n",
    "numpy.random.seed(0)\n",
    "numpy.set_printoptions(suppress=True)\n",
    "\n",
    "%load_ext watermark\n",
    "%watermark -m -n -p numpy,scipy,torch,pomegranate"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7dd56360",
   "metadata": {},
   "source": [
    "### Normal w/ Diagonal Covariance Distributions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5efcc291",
   "metadata": {},
   "outputs": [],
   "source": [
    "n, d = 100000, 500\n",
    "\n",
    "X = torch.randn(n, d)\n",
    "Xn = X.numpy()\n",
    "\n",
    "mus = torch.randn(d)\n",
    "covs = torch.abs(torch.randn(d))\n",
    "stds = torch.sqrt(covs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1c3325d9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "143 ms ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
      "227 ms ± 14.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
      "1.12 s ± 18.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "%timeit Normal(mus, covs, covariance_type='diag').log_probability(X)\n",
    "%timeit torch.distributions.Normal(mus, stds).log_prob(X).sum(dim=-1)\n",
    "%timeit scipy.stats.norm.logpdf(Xn, mus, stds).sum(axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd46b4b0",
   "metadata": {},
   "source": [
    "### Normal w/ Full Covariance Distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "07fab284",
   "metadata": {},
   "outputs": [],
   "source": [
    "d0 = Normal().fit(X)\n",
    "\n",
    "mu, cov = d0.means, d0.covs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "194d7679",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "211 ms ± 19.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
      "205 ms ± 22.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
      "765 ms ± 36.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "%timeit Normal(mu, cov).log_probability(X)\n",
    "%timeit torch.distributions.MultivariateNormal(mu, cov).log_prob(X).sum(dim=-1)\n",
    "%timeit scipy.stats.multivariate_normal.logpdf(Xn, mu, cov).sum(axis=-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a5adc6a",
   "metadata": {},
   "source": [
    "### Exponential Distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f70bb98d",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = torch.abs(torch.randn(n, d))\n",
    "Xn = X.numpy()\n",
    "\n",
    "means = torch.abs(torch.randn(d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ab3d0af3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "150 ms ± 1.31 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
      "89 ms ± 3.47 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
      "1.36 s ± 86.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "%timeit Exponential(means).log_probability(X)\n",
    "%timeit torch.distributions.Exponential(means).log_prob(X)\n",
    "%timeit scipy.stats.expon.logpdf(X, means)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5108fce5",
   "metadata": {},
   "source": [
    "### Gamma Distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "06865521",
   "metadata": {},
   "outputs": [],
   "source": [
    "shapes = torch.abs(torch.randn(d))\n",
    "rates = torch.abs(torch.randn(d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2459f3f0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "270 ms ± 9.09 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
      "250 ms ± 30.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
      "2.67 s ± 75.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "%timeit Gamma(shapes, rates).log_probability(X)\n",
    "%timeit torch.distributions.Gamma(shapes, rates).log_prob(X)\n",
    "%timeit scipy.stats.gamma.logpdf(X, shapes, rates)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c81f8f06",
   "metadata": {},
   "source": [
    "### Bernoulli Distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7cee5e63",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = torch.tensor(numpy.random.choice(2, size=(n, d)), dtype=torch.float32)\n",
    "probs = torch.mean(X, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "0f697993",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "181 ms ± 8.04 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
      "419 ms ± 20.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
      "3.78 s ± 66.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "%timeit Bernoulli(probs).log_probability(X)\n",
    "%timeit torch.distributions.Bernoulli(probs).log_prob(X)\n",
    "%timeit scipy.stats.bernoulli.logpmf(X, probs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
